{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of full passes over the training data made by the training loop\n",
    "epochs = 5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Example - Simple Vertically Partitioned Split Neural Network\n",
    "\n",
    "- <b>Alice</b>\n",
    "    - Has model Segment 1\n",
    "    - Has the handwritten Images\n",
    "- <b>Bob</b>\n",
    "    - Has model Segment 2\n",
    "    - Has the image Labels\n",
    "    \n",
    "Based on [SplitNN - Tutorial 3](https://github.com/OpenMined/PySyft/blob/master/examples/tutorials/advanced/split_neural_network/Tutorial%203%20-%20Folded%20Split%20Neural%20Network.ipynb) from Adam J Hall - Twitter: [@AJH4LL](https://twitter.com/AJH4LL) · GitHub:  [@H4LL](https://github.com/H4LL)\n",
    "\n",
    "Authors:\n",
    "- Pavlos Papadopoulos · GitHub:  [@pavlos-p](https://github.com/pavlos-p)\n",
    "- Tom Titcombe · GitHub:  [@TTitcombe](https://github.com/TTitcombe)\n",
    "- Robert Sandmann · GitHub: [@rsandmann](https://github.com/rsandmann)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SplitNN:\n",
    "    def __init__(self, models, optimizers):\n",
    "        self.models = models\n",
    "        self.optimizers = optimizers\n",
    "\n",
    "        self.data = []\n",
    "        self.remote_tensors = []\n",
    "\n",
    "    def forward(self, x):\n",
    "        data = []\n",
    "        remote_tensors = []\n",
    "\n",
    "        data.append(self.models[0](x))\n",
    "\n",
    "        if data[-1].location == self.models[1].location:\n",
    "            remote_tensors.append(data[-1].detach().requires_grad_())\n",
    "        else:\n",
    "            remote_tensors.append(\n",
    "                data[-1].detach().move(self.models[1].location).requires_grad_()\n",
    "            )\n",
    "\n",
    "        i = 1\n",
    "        while i < (len(models) - 1):\n",
    "            data.append(self.models[i](remote_tensors[-1]))\n",
    "\n",
    "            if data[-1].location == self.models[i + 1].location:\n",
    "                remote_tensors.append(data[-1].detach().requires_grad_())\n",
    "            else:\n",
    "                remote_tensors.append(\n",
    "                    data[-1].detach().move(self.models[i + 1].location).requires_grad_()\n",
    "                )\n",
    "\n",
    "            i += 1\n",
    "\n",
    "        data.append(self.models[i](remote_tensors[-1]))\n",
    "\n",
    "        self.data = data\n",
    "        self.remote_tensors = remote_tensors\n",
    "\n",
    "        return data[-1]\n",
    "\n",
    "    def backward(self):\n",
    "        for i in range(len(models) - 2, -1, -1):\n",
    "            if self.remote_tensors[i].location == self.data[i].location:\n",
    "                grads = self.remote_tensors[i].grad.copy()\n",
    "            else:\n",
    "                grads = self.remote_tensors[i].grad.copy().move(self.data[i].location)\n",
    "    \n",
    "            self.data[i].backward(grads)\n",
    "\n",
    "    def zero_grads(self):\n",
    "        for opt in self.optimizers:\n",
    "            opt.zero_grad()\n",
    "\n",
    "    def step(self):\n",
    "        for opt in self.optimizers:\n",
    "            opt.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../')\n",
    "\n",
    "import torch\n",
    "from torchvision import datasets, transforms\n",
    "from torch import nn, optim\n",
    "from torchvision.datasets import MNIST\n",
    "from torchvision.transforms import ToTensor\n",
    "\n",
    "import syft as sy\n",
    "\n",
    "from src.dataloader import VerticalDataLoader\n",
    "from src.psi.util import Client, Server\n",
    "from src.utils import add_ids\n",
    "\n",
    "hook = sy.TorchHook(torch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create dataset\n",
    "data = add_ids(MNIST)(\".\", download=True, transform=ToTensor())  # add_ids adds unique IDs to data points\n",
    "\n",
    "# Batch data; the images are exposed via dataloader.dataloader1 and the labels via dataloader.dataloader2 (see the cells below)\n",
    "# NOTE(review): the original note states that partition_dataset uses by default \"remove_data=True, keep_order=False\";\n",
    "# presumably VerticalDataLoader calls partition_dataset internally - confirm in src.dataloader\n",
    "dataloader = VerticalDataLoader(data, batch_size=128)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Check if the datasets are unordered\n",
    "In MNIST, we have 2 datasets (the images and the labels)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 4 0 7 0 1 5 1 5 1 "
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAV0AAAAqCAYAAAAQ2Ih6AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAAAdEElEQVR4nO2da1SU5drHfwMMgiAHQQ4CchBREFEsT6ipgBstksxjpa7C1Kxdu7Vq1W7b3n2otVud3a2ysrQy3WaZJCIqqKCYynkIBTkoMHIQGIEZBuYE87wfXD4rEg/oDPW+7/Nba77MPDPXn5nn+T/Xfd3XfSMTBAEJCQkJicHB5o8WICEhIfH/Ccl0JSQkJAYRyXQlJCQkBhHJdCUkJCQGEcl0JSQkJAYRu9u8/ke0Nsj6eU7S0RdJR18kHTfyZ9Ei6fgdUqYrISEhMYhIpishISExiNyuvPC/loaGBr788kt6e3t59tln8fX1/aMlSUhI3CEdHR1cvHgRjUZDUFAQwcHBf7Qki2EV0zUajTQ2NpKens62bdu4ePEiAIIgMHToUOLi4vjggw/w9va2RngxVnNzM7t27aK5uZm33noLLy8vq8W7HWq1ml27dvHdd9+xdOlSXnrpJavEMZlMNDY2cvDgQZRKJTU1NTg7OzN8+HDCw8OJiYlh1KhRDB061Crxb8aVK1eor6+nq6vrhtfc3d2JiIjAzu7/bA4gMQDq6+t5//33+eabb+jp6cHX15cHH3yQFStWMHr0aKv6xmBglbM8Ozubjz76iF9++QW9Xo/JZBJf6+zs5NixYxw8eJDk5GRrhAeuXcgbN27E29ubY8eOsXXrVl5//XWrxbsVarWaQ4cOkZaWhr+/P1OmTLFKHEEQ+OCDD9i2bRtXrlyhp6cHs9mMTCZDJpNhZ2eHp6cn4eHhLFiwgEceeYRRo0ZZRUtDQwMHDhygvb0dgPPnz5OXl0dra+sNx44ZM4aXXnqJFStWWEXLdbRaLSdPniQ/P5+zZ8/i7e3N6tWrmT59Ok5OTlaNfTOMRiMFBQV8/vnnHDlyBJlMRkBAABs2bODpp5/+QzRdJyUlhWeeeYbIyEg2b97MhAkTBiWuTqdDLpeL56Zer+e///0vO3fuFM+VZcuWDYoWa2BR0y0oKODQoUMcPHgQhUKB0WgEQCbrO4lnMBgwGAyWDH0DQ4cOJSIiArlcztWrVzl9+jS//vorUVFRVo3bH2q1mtzcXCoqKli7di3Tp0+3Spz9+/ezdetWamtrmT9/Pv7+/tja2gLXDEehUFBdXU1TUxO5ubkUFxezZcsWHB0dLa6loaGBL774gpqaGuCauZhMJnp7e284tqqqioyMDJYsWWK1bDctLY3t27eTn5+PWq3GYDBgY2NDRUUFjz32GA899BCjR4+2Suyb0d3dzffff8/WrVs5d+4c3d3dwDXTycnJ+UNNt76+ntzcXFpbW2lsbKSmpmbQTDc4OJjXX3+dl19+Gbj2fVy8eJHMzEx+/PFH3njjDQRBYPny5YOix9JY5Aw3m82cPXuWzZs3c+LECfGk/r3ZAmJ5YcmSJZYIfVOuZ3Zjxoxh6tSpFBQUcPToUcLDw5HL5VaN/XsMBgNarRZBEHB2dsbe3t4qcVQqFVqtlrVr1/Lcc8/h5+cnvmY2m9Hr9SiVSvbu3cvOnTvZt28fOp2Ot99+2+I1s97eXrq6utBoNLc9VqPRkJOTw/bt21m/fr1FdcC1m9F7772HQqG4obxRVFSETCbDz89v0E03NTWVL7/8kpKSEjEJkclkGI1G2tvbaW9vx93dfVA1AfT09JCXl8f+/fvFczYsLGzQ4tvZ2eHq6oqrqytwbQQ3cuRIIiMjmTZtGm+//Tbvvfces2bNYuTIkVbR0NHRQVpaGikpKdx3332sWLECrVaLra0tHh4eeHl50dHRQUdHx4DPG4uYbl5eHu+++y5ZWVl0dnZyq53LbG1tcXd3Z8SIEZYI
fVPMZjNarRaZTMaUKVOIjo7m4MGDzJ07l8mTJ1s19u8pKSmhoKCAsLAwZs2aZbU4MTExJCUlsWbNGsaNG4eDg8MNx4wcORIfHx9GjBjB1q1bOXHiBN9++y2vvvqqVTLe32Jvb8+cOXOIiorCaDRSUlLCyZMnkcvleHp6MmbMGIvHbG9v54svvqCwsBCj0cjkyZNZuXIlVVVVpKSkiDcqnU5n8di3Q6lU0tTUhIODA/Pnz8fDw4O0tDQ6Ojqorq4mMzPzD8nm1Go1ZWVl1NTU4OTkRGhoKCEhIYOu4zoymQx7e3u8vLyIjY3FYDDw7LPP8sMPP/Diiy9aJeb+/fvZsmULpaWlFBcXk5aWRltbG3Z2dqJ/dXR0oNPpWLt27YBGJRYx3dTUVHJzc+ns7OzzfHBwML6+vhQUFIilBqPRSEVFBcXFxVY1v0uXLrF9+3YAVq1axf33309xcTGVlZWDbrpdXV10dXXh6Ogo3r2tQWhoKJs2bcLLy6tfwwWQy+WEhISQnJzM0KFD2bRpE6mpqcybN485c+ZYVEtiYiIpKSl0dHQAkJiYSHJyMiEhIRw/fpycnBxRU0BAgFVq3f/5z38oLCxEr9czZ84c/va3vzF79mwOHjzI2bNnUalUdHZ2cvXqVYvHvhVXrlyhtrYWQRBITk5mw4YNqFQqTCYTu3fvpra2lp07d1rNdLVaLWfOnCEnJweFQkFERASrVq0iMjKSrq4u2traMBgM+Pr6MnXqVKuNzgaCTCbDxcWFmJgY5s+fzzfffGM10/31118pLy9Hp9Nx+fJl6uvrxbkpuVyOnZ0dZrMZR0dHmpqaBvTZ92S6PT09pKenk5WVJU6YAAwfPpyEhASWL19OXV0dFRUV4klta2uLm5sbnp6e9xL6lvz66698/fXX5Ofns2jRIvz8/JgwYQKhoaFkZ2fz6KOPDtpJpFaruXz5Mj09PQQEBODm5ma1WPb29gQFBd32OFtbW3x9fQkLC8NsNtPU1ERhYaFFTdfDw4O//vWvLF68WLzh+vv7M3z4cE6fPs1PP/1EdXU1AMOGDSMiIgJnZ2eLxb9OQUEBGo2G8PBwNmzYQGxsLK6urkycOJExY8ZQWlqKwWBArVZjMpkGrfSUm5uLQqFAp9MxatQowsLCxIxfEAT0ej2lpaXs3buXpUuXWjR2VVUV3333HZmZmdTV1dHW1kZeXh7h4eFERkZSX19PdXU1MpkMb29vYmJiLBr/twiCQFNTE+7u7nc00pLJZAwbNoyxY8dy4sQJi+sxmUycPn2a4uJiurq6sLOzIzg4mPDwcDw8PHBzc6O6upqMjAzMZjPBwcED/n3u2nQNBgMHDx7k008/paysTLywvL29iYqKYuHChSQkJNDY2IhKpRJn1K/XRKzVN6tUKtm5cyfV1dUsW7aMpKQkXFxcCA0NZebMmfz0009UVlYSGRlplfi/p7S0lLy8PARBwM/Pz6qZ7kCQyWTiJBtg8ZuQjY0NISEhfYalhYWFHD58mCNHjvDLL7+g1Wrx9vbmkUcesVqNX6fTYTabiYyMZOLEieL3b29vL/7N7e3tZGdnM3nyZBITE62i47eUlJSwZ88eysvL8fT0xM3Nja6uLgoLCzlz5ox4nF6v59KlSxaPn5qayp49e6ipqaGnpwdvb29mz54t/laXL1+moqICQRAYMmQIw4cPt7gGk8lETU0Ne/bsobi4mNDQUJYtW0ZUVBRDhgy55XsFQaCnp8fimgAuXLjAV199RXFxMWazmdmzZ/PUU08RFRWFg4MDTU1NfPHFF5hMJlxcXFiwYAFjx44dUIy7Nt22tja+/fZbcnNzxVnXgIAAlixZQkJCAuPHj8fR0ZGgoCBWrVpFT08PO3fuRKPRIJfLrZJRGAwG0tPTKSgoIDY2luXLl4vm7urqyvjx40lPTyc7O3vQTLempoaamhrUajUtLS1otVqrZrt3Snd3Ny0tLZjNZoYNG8akSZOsFquzsxOFQsGOHTs4deoU
jY2NGI1GgoODSUhIYN26dYSHh1stPlzLpq+brNFo5Pjx4+Tm5gLXzK2kpIRjx44xf/78217094pSqaSiogI7OzsSExOZPn06DQ0NnDx5kvPnz/c7AW0pjh07xo8//ohSqRRr7ElJScycOZPQ0FAAWltbaWhowN7eHicnJ4vX+pubm/nyyy8pKyvjwoULNDU1cebMGc6dO0dsbCwxMTFERET0e50IgiB2I1k6cbty5Qq7du3i2LFjqNVqwsLCWLp0KUlJSbi5udHY2EhaWpp4Y/T19SUpKQkbm4Et7L1r001JSSE/P180XICxY8eycOFC5s6dK5qqra0tISEhPPzww5w9e5aioqK7DXlbsrKySE1NJSgoiAULFvT5UWxsbPDw8GD48OGcPHmSdevWWf3igmsXtMFgoLOzk4aGBtRq9Z/CdJVKJSdOnMDW1hY/Pz+rTGLBtRLUtm3bSE9Pp7S0lJaWFlxdXXnwwQdZtGgR0dHRRERE9Mm6rUFVVRUFBQW0tbVRWVnJ3r17xXY2uNZBkZ+fzy+//EJsbKxVtZSXl3P16lWGDRvG5MmTGT16NCUlJbS0tGBjY0NUVBTu7u6Ul5dbPPZPP/1EeXk5BoMBV1dXxo0bx/Tp0wkPD8fBwYHLly9TWVlJV1cXrq6u+Pr64uHhYbH4ra2tbN68mTNnzjBlyhTi4uIwmUyUlJRQWloqmt6yZct45JFH+mTZZrOZhoYGcZRg6b7uo0ePkp6eTmtrKy4uLiQlJZGYmIibmxttbW1kZGSwa9cuLl++jLe3N48//vhdtdEN2HQ7OztJS0tj27ZttLW1IQgCMpmMcePG8fDDDxMZGXlDFiuXy/Hw8MDFxWXAAu+U6y0eGo2Gxx9/vN/MSS6XY2trS11dHXq9flBM9zoODg53XLeyNjqdDoVCQXZ2Nu7u7syZM8dqrUnFxcXs3r2b/Px84NqIIz4+nnXr1jFnzhyr19YTEhKorKyktLSUzz77DDc3N5RKJRcvXiQwMJBp06ah1+vJysqioqKCtLQ0q5uuQqGgpaWFKVOm4OvrK7ZIBQUFMWvWLJYtW4ZarbaK6drb24uZtF6vJy8vj+7ububOncuYMWNobGxEqVQC4OzsTEBAgEXP2c8++4xz586xfPlyFi5cSEBAAIIgUFNTw4ULF/j555/JyMhArVbj5OREYmIiTk5O9PT0oFQq2bdvH7t27cLHx8diJamenh727t3L119/TU1NDcOHDycxMZGVK1eKcyQajYbz589z/vx55HI5Y8aM4YknnrirlZ0DNt2ioiLef/99zp07J9Zxx4wZQ3JyMosXL+435TcajbS2toqz2JZGEAQOHDjAuXPnmDVr1i1XGNna2jJkyJABDwnuFVdXV0aNGjXofZdmsxmTyYTZbBafu3z5Mnl5eTQ0NBAVFUV8fLzFza+5uZns7GzS0tKora0Frg3xExISePrpp5kxYwb29vb09vai0+n6rFqUyWQ4ODjctANjIDzxxBO0tbWxc+dOcnJyxHbG8PBwMZtSqVR0d3eTnZ1NUVER1dXV4lDbGnR0dKDX6xk3bhwBAQEA+Pj4kJSUxNy5c5k6dSrHjx+3SuylS5eK/dpNTU3U1dVRWlpKUVERQUFBCIJARUWFeLwlSx1NTU1s376dTZs2sWLFij5JWFhYGGFhYeJk86FDh9i5cycjRoxgxowZlJWV8f3333P06FFcXFx48sknmTZtmkV0nT17lo8++kjsl46Li2Pt2rV9FlJ1dnaKqzz9/f1ZvHgxgYGBdxVvwKZbWlrKpUuXMJlM4g+ycOFClixZgr+/f7/v0Wg0lJaWinswWJpLly6xa9cuvLy8WLRo0S2/DEdHR4KDgy1yQQ+EESNGEBwcPKgLM3p7e1EoFCgUCtRqtfh8U1MTJ06cwNXVlZiYGKssClAqlbz77rviJKurqyuxsbGsX7+e2bNn09XVRWVlJS0tLZSVldHS0iK+18HBgWnTpjF79ux71uHn58fG
jRtxc3OjoqIClUqFq6srcXFxxMXF4efnh1qt5qGHHiI/P5/q6mp++OEHXnnlFausjtNoNJhMJgRBwMvLS5zYc3Z2ZsaMGcC1G1ZlZaXFYwPMmjWLgIAAGhsbqa2tpby8nIqKCsrKyjh69CharVY8VqfTWTRROnz4MDY2NsTHxzNs2LB+j4mOjsbJyQkHBwcOHTrEJ598gkKhoKioiOPHjxMVFcVTTz3FokWLLHINm81mvvvuOyoqKjAYDEyZMoWVK1f22QuktraWAwcOkJeXh5ubG7GxsTz22GN3nbgN6Kzq7e29YfnukCFDiI6OvulmMo2NjRw+fJiMjAwaGxtv+mXfLdfXZV+9epXk5GQmTpx4U2PT6XR0dnYSGho66KvSgoKCrD5Z9Fu0Wi1lZWV8/vnnHDhwAJVKdcMxAQEB2NrakpOTg6enJ2FhYTe9cQ6U8vJyWltbxfPFy8uL6Oho4Fq7VFVVFYWFhSiVSoqKimhsbBTf6+TkxPLly5HL5YwdO5by8nLGjx9/150fgYGB/P3vf0ej0VBfX4+Xl1eflkV3d3fuv/9+Jk+eTE5ODhkZGTz//PMWP1fhWudCW1vbTTNIk8lEWVkZqampFo99ncDAQAIDA5kxYwY6nY7GxkZOnTrF2bNnOXPmDBcuXBBHsZZkx44dREdH4+npecsMOigoiLi4OBQKBSkpKfz88894enoyb9481q9fT1xcnMU0mc1mcUn4uHHjeOGFF1iwYIE472IwGDh+/Djbtm2jsbGRadOmsWbNmnuaxBuQ6apUKkpLS/sMBcPDw8W61HV6e3vRaDRotVoyMzP5xz/+QUdHB/b29vj4+Fi0BaWuro69e/cSHR3NpEmTblpWMJlMKJVK6urqePTRRy0W/1Zcv0mZTCaGDBkyaDt7mc1msrKy2LFjB3l5eXh6ehIREUFLSws1NTUYjUbs7OxoaGjgk08+Yffu3fj7+/Pyyy9bbHLixx9/7GP0Go2GlJQUMjIygGvZXF1dXb8Xt8Fg4NixY9TU1LB06VL27NnDunXrWLVq1T1pcnFxISIi4obnZTIZQUFBTJ8+naysLFQqFSqVyiqmW1xcTFtb201fb21t5dSpU5w/fx4PDw98fHwsruG3ODo6Mnr0aEaPHs3SpUv56KOP+PDDDzEYDMjlcovWcxsbGwkICOh3/43rZabOzk4qKirIzMwUJzqHDh3KzJkz2bRpk8X3f7Czs2P9+vViB09cXFyfm/uFCxc4deoUKpWKoKAgli5des99ywMy3draWvbv39+nY2Hq1KkEBQVhb29PZ2cnbW1tXL16VVzRUVlZSUdHBzKZjLCwMJYtW8bUqVPvSfRvKS8vR6PRMGfOnJueoAaDgaqqKoqKivD29rboIoBbcb1joaurC3t7+0HZutBsNlNZWcmHH36IQqFgxowZrFixgpEjR5KSkoJWqxUnSHQ6nbioZejQoRbVV1VVhU6nEzOa5uZmmpubb3q8ra0tzs7OODo60tzcjFKpRKlUkpeXh1wuJyMj455N91Y4OjoyYsQI7Ozs0Gq1FBcXW3w/Cr1eT319PTqdDgcHBxwdHft85z09PZSXl5OZmYnJZCIwMJB58+ZZVMOt6O3tFWv/MpkMNze3u65b9seTTz7Jli1byM/PJyQkBFtbWwRBEPcwaGho4NKlSxw7doyysjJxVzw7OzvGjx9vtc2qYmJi+jXS9vZ2MdO2sbFh3rx5rF69+p5vRAO+yn5fx8jKymLSpElotVoqKys5efIk1dXVVFVVUVdXh52dHW5ubkRGRrJixQpWr15t0Yyvvb0duVyOi4tLvyUDjUaDQqHg559/pqysjIULF1olg+kPjUZDXV0dJpOJkSNHDsrCiM7OTt59912Ki4sJDw/n+eefx93dnY8//pgjR44QGRnJunXrmD9/vjjchmudHYO94Yu9vT1ubm64uroydOhQoqKiCAkJYc+ePeLE6/WtMP/9739bVcuwYcMICgrC2dkZg8Eg
zuBbErVaTU5ODs3NzQQHB+Pv79/nAm5ubub06dOcO3cOPz8/lixZIk60DQY9PT3IZDLkcjmCIODm5nZHKxzvlFWrVrFv3z7ee+89goKCsLGxwWQyUVBQQGtrK11dXeKK1ZkzZ3L//fdTX19vlZVnt8NoNHLkyBHS09Pp7u4mJiaGxMREi1zDAzJdJycnAgIC6OzsFFeEVFVV8a9//QsnJyd0Oh1qtRq9Xg9cu1v6+PiwZs0aNm7ciLu7u8WH2I6Ojjg5OVFVVUVFRQVeXl7Y2tqKu2plZ2eze/duTCYTS5YsYeXKlRaNfyuMRiO9vb1ixmCNlT2/p6WlhTNnzqBWq4mMjESr1fLNN9+QnZ3NxIkTSU5OZtGiRbi4uODt7W21/ly5XC5muUOGDBH7cOVyOa6urjg6OuLt7U18fDwPPPAA9vb2+Pn54enpyezZszl8+DBHjx7l/fffZ/To0RarNZvNZsxmMzY2Nn0SiM7OTmpra8VRibU2wLnePWEymTAajWJmaTAYKCkpITMzE5lMxl/+8hdeeOEFq2i4GW5ubowcORIXFxdx3+NbbV41UAICAvjkk0945513qKyspLu7W9yvJTAwEHd3d8aPH8+sWbOYNm0aPj4+ZGZm0tjYaNUFI/1RXl7O/v37KS0tJTQ0lFWrVvHggw9a5LMHZLqRkZE899xzvP766302o75eA+vzwXZ2ODo6Ehsby2uvvWa1TaKTkpI4fvw4W7ZsYfv27YwfPx53d3d6enqoqalBpVIxb948nnnmGSZMmDCoP55Sqex30+7BIiMjg71799LT08NDDz3Eiy++aLW9fH9PdHQ0V69exWAwMHHiRLFVLjAwkMWLFzN16lRx+83fj55iY2OZM2cOb731lkVb2cxmM5cvX+bKlSsMHz4cb29v8Qbd0NBAVVWVWH+3lMn/FplMJpp9Q0MDZWVl1NbWMmLECMrLy9mxYwfnz58nPj6eZ555xuLxb0dHRwdVVVU0NDSI/feWXrQybdo09u3bJ8ZTqVSYzWb8/f37Tciub3Jj7V0Jf4tKpWLPnj1kZ2fj7OxMQkIC8+fPt9jnD7i8MH36dKZOnUpWVlaf2i5cy2gcHBywsbHB39+fBx54gMTERKs2wDs6OvLGG28wYcIECgsLKSwsRKVSER0dzRNPPMHMmTMJCQn5Q/4zwM1MxZrY2tri4OCATCZDqVTi4+NDYmIiGzduHNTd1d5++21mz56NRqNhyZIlAx6m2traWvyCb25u5p///Cf79u3jvvvuIz4+Hg8PDzo6Ojh79izZ2dnIZDJGjBhhlU1erm9IVFZWRmdnJz/88ANZWVnI5XL0ej16vZ6EhAReffVVJk6caPH4t6O7u5v29nb0er24+Y61/rMIXMusb7c6Mz4+nvj4eKtp+D1arZaPP/6YHTt20NHRwYsvvsgrr7xi2f56QRBu9eiXS5cuCTNnzhQ8PT0FNzc38ZGUlCTs2rVLyM3NFUpLS4W2trabfcStuGMdVsYiOr7++mshKSlJ+OqrrwZFh9FoFN555x0hLCxMCAgIEF577TWhurr6bmPftQ4rctc6cnJyhBkzZghAvw+5XC6MHz9e2LJli6DX662i49KlS0JSUpLg6Ogo2NjYiI+xY8cKb775pnDx4sU7/XNupeOufpv6+nrhueeeE4YMGSKEhYUJn376qSW0/BHclY7e3l4hNTVVmDRpkmBnZycsX75cyMvLs7SOuzNdKyPpkHRYRUd1dbWwYcMGwdfXV3BwcOjzcHd3Fx599FGhtLTU6jpSU1OF6dOnC87OzuJjzZo1QlVV1Z1+xO103NVvYzQahczMTGH9+vXCm2++KVy9etUSWv4I7kqHQqEQZs2aJdjY2Ajz5s0T0tPT7+TmO1AdyIRbF8otV0W/c/oruko6+iLp6Iuk40b+LFr+V+hoamrinXfeYffu3eh0OjZv3szq1avvdRFVv7/N4G5AICEhIfEn5PDhw6SmptLa2kpycnKfnRIt
jZTp3hxJR18kHX35M+uAP48WScfvn7yN6UpISEhIWBCpvCAhISExiEimKyEhITGISKYrISEhMYhIpishISExiEimKyEhITGISKYrISEhMYj8D/Umt44n+RMkAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 10 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# We need matplotlib library to plot the dataset\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Plot the first 10 images and print the first 10 labels.\n",
    "# Dataset entries are indexed from 0, whereas subplot positions start at 1,\n",
    "# so the subplot position is index + 1.\n",
    "figure = plt.figure()\n",
    "num_of_entries = 10\n",
    "for index in range(num_of_entries):\n",
    "    plt.subplot(6, 10, index + 1)\n",
    "    plt.axis('off')\n",
    "    plt.imshow(dataloader.dataloader1.dataset.data[index].numpy().squeeze(), cmap='gray_r')\n",
    "    print(dataloader.dataloader2.dataset[index][0], end=\" \")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compute the private set intersection (PSI) and order the datasets accordingly"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute private set intersection\n",
    "# The IDs of the first dataset are given to the PSI client, those of the second dataset to the PSI server\n",
    "client_items = dataloader.dataloader1.dataset.get_ids()\n",
    "server_items = dataloader.dataloader2.dataset.get_ids()\n",
    "\n",
    "client = Client(client_items)\n",
    "server = Server(server_items)\n",
    "\n",
    "# The server processes the client's request; the client computes the intersection from the returned setup and response\n",
    "setup, response = server.process_request(client.request, len(client_items))\n",
    "intersection = client.compute_intersection(setup, response)\n",
    "\n",
    "# Order data\n",
    "# NOTE(review): presumably this keeps only the data points whose IDs are in the intersection and then sorts\n",
    "# both datasets by ID so that images and labels line up - confirm in src.dataloader\n",
    "dataloader.drop_non_intersecting(intersection)\n",
    "dataloader.sort_by_ids()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Check again if the datasets are ordered"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 6 7 3 9 1 8 1 9 8 "
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAV0AAAAqCAYAAAAQ2Ih6AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAAAZjUlEQVR4nO2deVDU9/3/H3uwLDesyCFXBFS8OCQGjKJgjUdMTNRUw9R4JGma1Gbapp1MjqbppJ00M03S1k6bxCOxxjjGxCZE0RiVu3ggAQS5RUCQGxZYYJf9sJ/fHwyfX4gnurv4TT+PmZ3RXdjXk8/x3Pf79Xq936sQRREZGRkZGfugHG8BMjIyMv9LyKYrIyMjY0dk05WRkZGxI7LpysjIyNgR2XRlZGRk7Ij6Jq+PR2uD4hrPyTpGI+sYjazjau4WLbKO7yGPdGVkZGTsyA/edPv6+tixYwcxMTG88sor4y1HRkbmf5ybpRf+T2OxWDhy5Ajbtm1Do9EQHx8/3pJk7hJqamrYt28fRqORDRs2EBERMd6SZP5H+EGbbnp6Ojt27KCtrY3NmzezdOnS8ZYkcxdQUVHBW2+9xcGDBxEEgerqat544w2mTp1qdy05OTmkp6fT0dFBd3c3VVVVTJo0iTfffJPw8HC76xnhyy+/ZM+ePfzkJz9h7dq146bjh4hNTLe7u5s9e/aQkZHBE088waOPPmqLMDekoKCAPXv2cOrUKaKjo1myZAlardbuOuyNXq/n7bff5uDBg3R2dqJQKBBFkYkTJzI0NDTqOS8vL6Kjo4mIiCA+Pp7ExMQf/DHq7Oxk+/btpKamYjAYUKlUKJVKnJ2d7a7l+PHjvP322+Tm5iKKIkNDQ1gsFnQ6Hbt27eLPf/6z3TWN4Ovri1qt5vjx4//TpltSUsKePXs4fPgw7e3tKBT/vzam0+l46aWX2LRp09jeVBTFGz1ui+zsbHHVqlWiVqsVw8PDxT/84Q9j+XWr6Hj99ddFd3d3UavVis8++6xoMBjG+ha3rGNwcFDs6OgQd+/eLcbHx4u/+93vxF/+8pfivffeK4aEhIjBwcFicHCwOHfuXPGdd94R+/v7baJDFEXxhRdeEP39/UW1Wi0qFArp4eDgcNVzarVadHJyEj08PMSAgAAxOTnZajpuhl6vF7Ozs8Xf//734vz588Xly5eLxcXFt/Krt62jqKhITE5OFnU6nahUKkWFQiFqNBpx69atoiAIY/0T7vh4vPXWW2JAQIDIcGVdeiiVSjEqKko8evTo7eq4o3MjiqJ44MABMSgoSExKSrrV83I9LePBHevIy8sTn3/+eTEiIkL08PAQ1Wr1VedJq9WKy5cvF5ubm8eiw/oj3ZKSEg4ePEheXh5Go5H29nY6OjqsHea6iKJIdnY2p06dore3l6SkJNavX4+Li4vVYwmCQG1tLUVFRbzzzjvU1tbS1dVFdXU1Q0NDDAwMIAgCFosFgNbWVoxGI0qlks2bN+Pp6Wl1TR4eHgiCgK+vL+7u7kRHR+Pu7o6Liwuenp54enoyefJkBEGguLgYg8HAmTNnyMzMJD8/3+p6vovZbKaiooJ9+/aRnp6OXq+ntbWVvr4+HBwceO211/jwww/x8vKyeuyioiLefPNNvvnmG7q7u0e9plQqUalUVo95MwYHBxEEAYVCgVarxcnJCbPZjMFgwGQyMTAwYHdNAD09PTQ0NNDb24uDgwOOjo52iWs0GiktLeX9998nKysLBwcHoqKiePLJJ1m8eLFdNAwNDVFWVsbf/vY3aTY0NDQEMGqUC2Aymbhw4QIVFRX4+vrecgyrm+6lS5fIz8+nra0NgNDQUJYsWWLtMNdlYGCAzz//nNzcXCZOnMj999/P7NmzbRIrMzOTV199ldbWVpqamqST892bWqFQoFQqsVgsDA0NUV1dzf79+wkICODHP/6x1TUlJycTGRmJj48Pnp6euLu7S1NotVqNSqVCq9UiiiIJ
CQnU1tZSV1cHDJuALejv76e0tJRDhw5x8uRJKisr6e7ulqbUMGzIBQUFpKSksHnzZqtrOHToEKdPn6anp4eQkBBMJhNtbW2Ehoba7Pq4EY2NjZSVldHT08PKlSvZvHkz06ZNw2Kx0NraSnl5OW5ubnbXZbFYKCwsJC0tDVEUCQgIICQkxKYxRVEkJyeHvXv3cubMGWpraxEEAZVKxeXLlzEYDBiNRoKCgpg+fTpqtW1KUYIgkJKSwo4dO6RrBUCtVuPn58eiRYvw8fFh9+7ddHV1oVAoUKvVaDSaMcWxuvrm5maam5sRBAGtVktYWBgzZ860dphrMjg4yGeffUZGRgYGg4GHH36YdevW2WTkBPDKK69w/vx5yThuBUEQuHLlCgUFBSxYsAB/f3+rapo8eTIBAQFoNJqbXpxtbW188cUXpKeno9VqmT59ulW1wPBoYM+ePRw4cICysjL0ej0mkwkYPXIQRRGVSoWPj4/VNQA4OTmhUCiIi4sjMTGRM2fO0NbWhpeXF35+fjaJeSNGRrKOjo7ExMQwb948rly5QmdnJ0lJScTExNjMXG5Ee3s7aWlp5OTkEBgYyMKFC8dsKmPBYDDw5Zdf8tFHH/Htt98CsHr1alasWEFvby+HDx/mzJkzVFRU4O/vzxNPPMGWLVtsoqWhoYFPPvmErKwsjEaj1PGUlJTE7NmzmTFjBoWFhXz++ed0dXXh6OjIjBkzmDVr1pjiWP2s9vX1SdMiQRDo6emhr6/P2mGuSVNTE6mpqVRVVTFt2jQWL17MlClTUCpt045cWVkpTQ/HQmdnJ1lZWTYxXbVafdObtaGhgfT0dNLT0zl+/Dg9PT3Mnj2b3/zmN1bVArBjxw527dpFeXn5DUfSrq6uREdHExMTY3UNAI8++ij33HMP3t7e+Pn5odfryc3NvaXjZQv8/f0JDw8nNzeXsrIydu7cKRXUDAYDa9assbsmGB64dHV14erqytKlS0lISLBpvGPHjvHBBx+Qn5+Pr68vTz31FGvXriU4OBiz2YxOp6Ojo4OcnBwaGhrw9va2mekePXqUwsJCjEYjKpWKxYsX8+yzzzJ37lwcHR0pKipiz5490ize1dWVRYsW4erqOqY4Vr/aRqbRMDyScXBwsNtFnZaWRnFxMUajkYULF7Jw4UK75aNgeJSp1Wqpr6+/4QfN4OAgNTU1ZGZmEhUVRUBAgM21mUwmmpubycvL48iRI5w9e5aWlhYUCgUrVqwgOTmZ+fPnWzXm4cOH2bt3L2VlZaMMd926dcybN4/U1FROnDgBQFhYGC+//LLNRrqTJ09m0qRJqNVqlEolOp1uXMx2BCcnJ0JCQnB1dSU7O5u8vDxaW1vx8PDg3Llz42a6NTU1VFVV4enpycyZMwkODrZZrO7ubj799FOKiooQRZGNGzeyceNGgoKCUCgUWCwWJk2ahLe3NwAqlcomtZkRLXl5eXR1dQEwa9YskpOTSUpKQqPRkJ2dzd///neys7OlUXBYWBgPPPDAmGPZ9KoLDw9n5cqVBAYG2jIMABcvXuTQoUPU1dURHBxMXFwcISEh0ijUYDBw6dIlSkpKOH78OBqNhlmzZvGLX/zitmNu2rQJo9EIQH19PV1dXQwNDd3SyFev1/P111+j0Wgks16zZo1VW7ZKSkro6OigpaWF8vJyysvLqaiooKamBpPJxJw5c9iwYQOxsbFERERYNXZlZSUfffQRpaWlkuEGBQWxdOlSnnrqKRwcHMjNzZV+3snJiaCgIJsVtJRKJU5OTtL/VSrVmGco1mbp0qVkZWXx2WefSc85OTmNawEtJyeH/Px8YmNjmTt3Lg4ODjaLd+XKFSorKxkYGOChhx5i1apVBAYGolAo6Ovro6CggAMHDlBQUACAj48Py5cvt4mWtrY2Ll68SH9/PwATJ04kJCQEtVpNWloa//rXv8jIyKCvrw9HR0eio6N58cUXmTZt2phj2dR0AwMDiY2Nxd3d3ZZhgGGDqaqqwmw2
s2TJEuLi4nB0dMRkMlFcXMzhw4c5f/48ly5dorCwELVazbRp07jnnnt46KGHbivm1q1bpc6EtLQ03nvvPVpaWiQjvhGCIFBVVcXevXtxdXVFp9Px4IMPWsX4cnJySEtLo6CgAIPBQGdnJ01NTXR1dWE0GlEoFEydOpV169axZs0am4wuv/nmG/Lz86URv7u7OwkJCaxatYrW1la+/vprCgoKJOPr6uri9OnTPPLII1bXcj2+28YzHvj5+aHT6aRjMGXKFFasWEFSUtK46CkvL+fs2bO0tbWh0+lsPljq6OhgYGAAb29v1q5dy7Rp01AqlXR0dHDs2DEOHDhAYWEhnZ2dqFQqfH19iY2NtYmWgYEBenp6MJvNwPAg7ujRo+Tl5XHs2DFOnTo1ynBfeOEFVqxYcVv3q1VNt6Ojg6amJumTWq1W22V6L4oiRUVFdHR0oNVqiYuLY8qUKQwODpKdnc327dvJzs4Ghi/sVatWUVlZycWLF9m5cycrVqy4rRHWlClTpH9XVFTQ19cnTU9uBUEQaGxsBMDFxYWPP/6Y559/fsw6vk9GRgbvv/8+TU1NKBQKJkyYgJ+fH76+vjQ3N9PS0kJ/fz/FxcV4eXnxwAMPWL2YVF1dTW9vr2RoFouFhoYGDh48SF1dHUVFRfT29ko/39/fz6VLl6yq4Vbw9PS0WUrjZmRlZXH+/PlRWubPnz9upnvx4kXq6upQKBR2mQWMtFNOnz6dmTNnolAoqKysJC0tjQMHDtDV1UVCQgLt7e2cPXuW2bNn26zoOWHCBCIjI6mtraW7u5vLly+zf/9+RFGkubkZs9mMh4cHc+fOZcuWLTz00EO3PUCyqum2tLRQV1dHT08PSqUSR0dHm05PRhAEgcLCQrq6upg1axahoaEIgkB+fj67du3i2LFjhIWFsXLlSubNm4dWq2Xv3r1cvHiRlpaWcRvpfBdRFNHr9VZ5r4iICBITE4HhaVJAQAD+/v6o1WqampooKioiLy9PyqcZDAaefPJJq35ABgcHj6p69/X1kZ2dTXZ2tnS8v3tjOzk52SUNBcOpHb1ez9DQEAEBAYSFhdkl7ncxGAycPHmS0tJSRFFEoVDQ2NhIVVWVTVoJb4X6+nquXLmCh4cH/v7+Nl+d6O7ujrOzMwMDA2RlZZGVlUVpaSllZWV4enqyfv16wsLCSElJobm5mYcffnjMRatbxcfHh40bN9LW1kZOTg4Gg0FqpQSkLpNf//rXJCYmjkpVjRWrmm5vby8GgwGLxYJWq5UOqq0ZGhqSFh4kJCQQFhZGQUEB27dv58SJE4SEhPCzn/2MNWvW4OXlRWlpKZ2dnTg7O5OUlGSVPGJISAj33Xcf/f399PX1YTabGRoaQhRFLBbLTaexTk5OY19OeB0efPBBQkND8fDwICAg4Kqb5/Llyxw+fJgvv/ySc+fO8cknnxAVFcX9999vlfgAS5Ysobi4mJSUFPR6/ai/3dfXF09Pz1ELZ7y8vLjvvvusFv9G1NfXc/nyZcxmM25ubqNaCkcWtWi1WpsW2ioqKqioqAAgJiYGi8VCdXU11dXVNot5M1paWmhtbSU+Pp74+Pg7MpZbITg4mAcffJCjR4/ywQcfIAgCLi4uxMTE8PjjjzN//nzOnj3LhQsXmDhxok2vD7VazcKFC8nNzaWkpASDwXDV6xERESxduvSO/cKqV1VtbS3Nzc3AsImMNOjbE39/fywWC6mpqaSmpqJUKlm2bBmPPfaY1PZx8OBBMjMziYyMZOPGjVaZSkVFRfGrX/2KKVOm0N7eTmlpqZSYt1gsmEwmqT/1elirMuvs7MycOXOu+3pQUBCbN28mKCiI1157jYqKClJSUqxqupGRkbz44ou4ublx+fLlUcd4xowZaLVaUlNTJdN1cnKyaaX8u4y0MX63v3pkpnH+/Hlqa2uJjIy0WfsaDG/GVFZWxowZM3jmmWcYHBzkvffeo6uri/b2dqliby/0er1kNCPtbLZOMXh7e7N1
61YCAgIoKirC3d2dyMhI5s2bx+TJkzGbzZSXl1NQUEBcXJxN+4VhuJXyypUr16zJCIJAfX09hYWFzJkz546OjdVM12KxUFxcLA3J3dzc8Pf3x8PDw1ohrotSqcTd3R2NRsPFixdJTU3l1KlT6PV6EhMTWbBgAW1tbZw7d44jR46QnZ1NVFQUTz/9tFW39IuLiyMuLg4YLiT997//paenB5PJREFBwVXLbJVKJS4uLmg0Gry9ve1aTR9pWQoODqa6upqamhqrx5g+fTp/+tOfpOLECBqNhoyMDI4cOYJCoUClUtm1ta+rq0vKJ3d2dtLS0oKzszM5OTm8++671NbWsmHDBpua7sgy+UceeYSVK1diNBqpqqrixIkTZGZm2n2TmYaGBtrb29Fqtbi4uNglLQjDxfbrdRAZDAY6Ojqslna7EfX19ezevZvU1FTa29txdnZGp9OhUqnQ6/V0d3eTm5vLP//5T15//fU7WqVnNdPt6Oigvb1dusGcnZ3tYrgADg4OREdH8+233/Kf//yHkydP0t7eDgznDSsqKvjiiy84ceIEvb29xMbG8vLLL99Wj92tsnTpUmkryb6+Pt59992rTNfFxYXFixcTFhaGi4vLHefQ2tvbpV2qbjY1NplMtLa20tHRgUqlstmS02t1rlRXV3P8+HGKi4uB4dSCvfazFQSBgoICqWh39uxZdu/eTWhoKPv37yczMxNPT8+rppfWxGQyYTabSUpKYtmyZUycOJGWlhapxzsnJ8fupjuSAvPy8iIyMnJct5UcwWg00t/fLxXkbTUoGRwcZN++fezatYsrV64AcO+997Js2TJcXFykRUR6vZ7MzEwOHTrE1q1bb1uP1Uy3urqayspKaQTh4eGBTqez1tvfEIVCwaJFi8jLy+P06dO0t7dLOcTMzExOnTqFq6sr/v7+JCUlsWnTJrvuB3E9JkyYwKZNm267Ze377N+/H6VSybp16244PR0YGKCoqIh9+/Zx6tQpQkNDWbBggVU03ApHjx7l008/pb+/HycnJ2JiYuxiMoIgUFZWRklJibQ/xoULF3jrrbfQarW0trbi6urKjBkziI6OtpmOuro6ent7CQ8Px83Njc7OTi5fvixdt4IgMDQ0ZNdNeKqqqmhsbMTBwQEnJye7jXRvRGNjI9XV1bi5uREUFGSz9EJDQwPffPONlOry8fHh6aef5vHHH5da1UZ63Ht7e6mpqUEQhNs+RjarFJhMJrst/wVITExErVazc+dOTp8+TUtLi5SYnzx5MsuXL+fRRx9l+vTp47KRiD34+OOPCQ4OZsWKFdc13d7eXrKzs3nvvfc4ceIEOp2OhQsX8thjj9lFo9lspqmpSVpKec8995CcnGwX029sbOSvf/0rx48fl/J2I7lcjUaDv78/Cxcu5MUXX2TGjBk201FXVyetgPL398dsNpOenk5+fj6CINDb28vAwIDNKvXX4vTp05SXl+Pn53fX7KlsNBoxGAwEBgYSHx9vM9MdKayaTCbUajXr168nOjqa/v5++vv76ezsHLUxU39//x11PNnEdB0dHVEqlWPaCMYaMX/0ox8xZ84cMjIy+Pjjj2lsbCQpKYnk5GSioqLspmW8CA0NpbCwkAsXLuDu7o6npydDQ0OYzWYEQaCvr4/Dhw/zj3/8g5KSEjw9PUlMTOTpp5+2ywIWGK7aX7p0SWqT0ul0dmvZOnPmDPn5+XR3d0s7ro3sEhUWFiZtpmJrswsJCcHDw4OTJ0+SmZk56jVfX18CAwPtargj14darSY2NtZmCxDGisFgwGAwMGXKFJumAr/bl6xSqdBoNJSWlpKWlkZubi5nzpyRGgR8fX2ZN2/eHX0AWM10HR0d8fb2xsXFBR8fH376059arQVqLHh5ebF69WpWr15t99jjzdSpUzl58iS//e1vWbVqFcnJyTQ3N1NdXU1jYyMFBQVSP7ObmxsLFizgjTfesMnuYtfCZDKRnp7O2bNnpQKaWq2221Q2LCyM
kJAQWltbmTRpEgEBAQQGBhIXF0d8fDzh4eF2mdJPnToVPz8/1Gq1VANRKBRoNBpiYmLs/k0rFRUV1NXV4eLiwvz5821aQLxVzGYzVVVVlJWVER4ePqb9aseKu7s73t7e1NfXYzKZeOedd676GYVCgaOjIzqdjokTJ95RPKuZ7pw5c3j11VelzbHHe1373c7ISbRmL2hCQgIZGRl8++23vP322/zlL38BkPbSHdmsOywsjC1btvDcc8/ZNdUysgpupIg1YcIE5s6da7etP2NjY/nqq6/sEutmrF69mvr6eqmYqNVqiY+P5+c//zn33nuvXbV4eHgQFRVFcHDwDVsN7UlDQwNlZWVXbThvC2JiYnjppZf44x//SEVFBYIgSHl1BwcHHBwc8PLyYtGiRTzzzDN3vPOaVdMLfn5+3HfffWg0GiZMmGDNt/4/j0qlQqVSSRvijCzYsGYuc8mSJeh0Ov7973+TkpIifadTeHg4s2bNIiAggFmzZrFo0SKCgoKsFvdWGWmRG9lEfc6cOTZdZXQ3s2bNmnHbSez7BAUF8frrr4+3jHFl1apVeHt7c+HCBfr6+khJSaGpqYlFixYRGxtLdHS09RZnXO97fG7ne4WsxA9SR319vfjyyy+LWq1WjIiIELdt2yYajUa767gDrKLjq6++EhMSEsTZs2eLO3fuHDcdVuBu1nE3abktampqxOeee07UaDTili1bxk3HHXLNc6MQb1yFG49NCa6Vl5B1jEbWMRpZx9XcLVpuS0dbWxvbtm1j27ZtrF27lg8//HBcdNwh1zw3suleH1nHaGQdo7mbdcDdo0XW8f0nb2K6MjIyMjJWxDZfHiYjIyMjc01k05WRkZGxI7LpysjIyNgR2XRlZGRk7IhsujIyMjJ2RDZdGRkZGTvy/wDA3Qi2P9HrPwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 10 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# We need matplotlib library to plot the dataset\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Plot the first 10 images and print the first 10 labels.\n",
    "# Dataset entries are indexed from 0, whereas subplot positions start at 1,\n",
    "# so the subplot position is index + 1.\n",
    "figure = plt.figure()\n",
    "num_of_entries = 10\n",
    "for index in range(num_of_entries):\n",
    "    plt.subplot(6, 10, index + 1)\n",
    "    plt.axis('off')\n",
    "    plt.imshow(dataloader.dataloader1.dataset.data[index].numpy().squeeze(), cmap='gray_r')\n",
    "    print(dataloader.dataloader2.dataset[index][0], end=\" \")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fix the random seed before the layers are created\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Define our model segments\n",
    "\n",
    "input_size = 784  # 28 x 28 MNIST image flattened to a vector (see the .view call in the training loop)\n",
    "hidden_sizes = [128, 640]\n",
    "output_size = 10  # one output per digit class\n",
    "\n",
    "models = [\n",
    "    # Segment 1 (sent to alice below): flattened image -> hidden representation\n",
    "    nn.Sequential(\n",
    "        nn.Linear(input_size, hidden_sizes[0]),\n",
    "        nn.ReLU(),\n",
    "        nn.Linear(hidden_sizes[0], hidden_sizes[1]),\n",
    "        nn.ReLU(),\n",
    "    ),\n",
    "    # Segment 2 (sent to bob below): hidden representation -> log-probabilities\n",
    "    nn.Sequential(nn.Linear(hidden_sizes[1], output_size), nn.LogSoftmax(dim=1)),\n",
    "]\n",
    "\n",
    "# Create one optimiser per segment, each bound to that segment's parameters\n",
    "optimizers = [\n",
    "    optim.SGD(model.parameters(), lr=0.03,)\n",
    "    for model in models\n",
    "]\n",
    "\n",
    "# Create some workers\n",
    "alice = sy.VirtualWorker(hook, id=\"alice\")\n",
    "bob = sy.VirtualWorker(hook, id=\"bob\")\n",
    "\n",
    "# Send model segments to their locations (segment 1 -> alice, segment 2 -> bob)\n",
    "model_locations = [alice, bob]\n",
    "for model, location in zip(models, model_locations):\n",
    "    model.send(location)\n",
    "\n",
    "# Instantiate the SplitNN class with our distributed segments and their respective optimizers\n",
    "splitNN = SplitNN(models, optimizers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(x, target, splitNN):\n",
    "    \n",
    "    #1) Zero our grads\n",
    "    splitNN.zero_grads()\n",
    "    \n",
    "    #2) Make a prediction\n",
    "    pred = splitNN.forward(x)\n",
    "    \n",
    "    #3) Figure out how much we missed by\n",
    "    criterion = nn.NLLLoss()\n",
    "    loss = criterion(pred, target)\n",
    "    \n",
    "    #4) Backprop the loss on the end layer\n",
    "    loss.backward()\n",
    "    \n",
    "    #5) Feed Gradients backward through the nework\n",
    "    splitNN.backward()\n",
    "    \n",
    "    #6) Change the weights\n",
    "    splitNN.step()\n",
    "    \n",
    "    return loss, pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\Pavlos\\anaconda3\\envs\\pyvertical-dev\\lib\\site-packages\\syft\\frameworks\\torch\\tensors\\interpreters\\native.py:156: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the gradient for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations.\n",
      "  to_return = self.native_grad\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0 - Training loss: 1.146 - Accuracy: 72.833\n",
      "Epoch 1 - Training loss: 0.385 - Accuracy: 89.170\n",
      "Epoch 2 - Training loss: 0.317 - Accuracy: 90.945\n",
      "Epoch 3 - Training loss: 0.281 - Accuracy: 91.982\n"
     ]
    }
   ],
   "source": [
    "for i in range(epochs):\n",
    "    # Statistics accumulated over one epoch\n",
    "    running_loss = 0\n",
    "    correct_preds = 0\n",
    "    total_preds = 0\n",
    "\n",
    "    # NOTE: `data` below rebinds the name used for the MNIST dataset in an earlier cell\n",
    "    for (data, ids1), (labels, ids2) in dataloader:\n",
    "        # Send the images to the location of the first segment and flatten each image into a vector;\n",
    "        # send the labels to the location of the last segment\n",
    "        data = data.send(models[0].location)\n",
    "        data = data.view(data.shape[0], -1)\n",
    "        labels = labels.send(models[-1].location)\n",
    "\n",
    "        # Run one training step on this batch\n",
    "        loss, preds = train(data, labels, splitNN)\n",
    "\n",
    "        # Collect statistics (.get() retrieves the remote results)\n",
    "        running_loss += loss.get()\n",
    "        correct_preds += preds.max(1)[1].eq(labels).sum().get().item()\n",
    "        total_preds += preds.get().size(0)\n",
    "\n",
    "    print(f\"Epoch {i} - Training loss: {running_loss/len(dataloader):.3f} - Accuracy: {100*correct_preds/total_preds:.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Labels pointing to:  (Wrapper)>[PointerTensor | me:88412365445 -> bob:61930132897]\n",
      "Images pointing to:  (Wrapper)>[PointerTensor | me:17470208323 -> alice:25706803556]\n"
     ]
    }
   ],
   "source": [
    "# `labels` and `data` still hold the last batch of the training loop above; both are pointers to remote tensors\n",
    "print(\"Labels pointing to: \", labels)\n",
    "print(\"Images pointing to: \", data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
